/*
 * Copyright 2024 - 2024 the original author or authors.
 */
package com.kuafuai.manage.mcp.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kuafuai.common.util.StringUtils;
import com.kuafuai.manage.mcp.spec.*;
import io.jsonwebtoken.lang.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URLEncoder;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport
 * specification. This implementation provides similar functionality to
 * WebFluxSseServerTransportProvider but uses the traditional Servlet API instead of
 * WebFlux.
 *
 * <p>
 * The transport handles two types of endpoints:
 * <ul>
 * <li>SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client
 * events</li>
 * <li>Message endpoint (configurable) - Handles client-to-server message requests</li>
 * </ul>
 *
 * <p>
 * Features:
 * <ul>
 * <li>Asynchronous message handling using the Servlet async API
 * ({@code javax.servlet.AsyncContext})</li>
 * <li>Session management for multiple client connections</li>
 * <li>Graceful shutdown support</li>
 * <li>Error handling and response formatting</li>
 * </ul>
 *
 * @author Christian Tzolov
 * @author Alexandros Pappas
 * @see McpServerTransportProvider
 * @see HttpServlet
 */

@WebServlet(asyncSupported = true)
public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider {

    /**
     * Logger for this class
     */
    private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class);

    /**
     * Character encoding name set on the SSE response and on JSON error responses
     */
    public static final String UTF_8 = "UTF-8";

    /**
     * Content type set on JSON error responses written by the message endpoint
     */
    public static final String APPLICATION_JSON = "application/json";

    /**
     * Log message template used when an error response itself cannot be written
     */
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";

    /**
     * Default endpoint path for SSE connections
     */
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";

    /**
     * Event type for regular messages
     */
    public static final String MESSAGE_EVENT_TYPE = "message";

    /**
     * Event type for endpoint information
     */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /**
     * JSON object mapper for serialization/deserialization
     */
    private final ObjectMapper objectMapper;

    /**
     * The endpoint path for handling client messages; compared against the full
     * request URI of incoming POST requests
     */
    private final String messageEndpoint;

    /**
     * The endpoint path for handling SSE connections; compared against the full
     * request URI of incoming GET requests
     */
    private final String sseEndpoint;

    /**
     * Map of active client sessions, keyed by session ID. Entries are added when an SSE
     * connection is established and removed by the async listener or by the session
     * transport when the connection ends.
     */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();

    /**
     * Flag indicating if the transport is in the process of shutting down; once set,
     * new GET and POST requests are rejected with 503
     */
    private final AtomicBoolean isClosing = new AtomicBoolean(false);

    /**
     * Session factory for creating new sessions. Assigned through
     * {@link #setSessionFactory(McpServerSession.Factory)} and dereferenced when an SSE
     * connection is accepted, so it has to be set before the first client connects.
     */
    private McpServerSession.Factory sessionFactory;

    /**
     * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
     * endpoint.
     * <p>
     * All arguments are validated eagerly: the mapper and both endpoint paths are
     * dereferenced while handling requests, so rejecting {@code null} here reports the
     * misconfiguration at construction time instead of as a
     * {@link NullPointerException} inside request handling.
     *
     * @param objectMapper    The JSON object mapper to use for message
     *                        serialization/deserialization
     * @param messageEndpoint The endpoint path where clients will send their messages
     * @param sseEndpoint     The endpoint path where clients will establish SSE connections
     * @throws NullPointerException if any argument is {@code null}
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint,
                                                 String sseEndpoint) {
        this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper must not be null");
        this.messageEndpoint = Objects.requireNonNull(messageEndpoint, "messageEndpoint must not be null");
        this.sseEndpoint = Objects.requireNonNull(sseEndpoint, "sseEndpoint must not be null");
    }

    /**
     * Convenience constructor that serves SSE connections on
     * {@link #DEFAULT_SSE_ENDPOINT} and delegates to the three-argument constructor.
     *
     * @param objectMapper    The JSON object mapper used to serialize and deserialize
     *                        messages
     * @param messageEndpoint The endpoint path to which clients POST their messages
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    /**
     * Stores the factory that is later used to create a session for each newly
     * established SSE connection.
     *
     * @param sessionFactory The session factory to use
     */
    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification with the given method and parameters to every session
     * currently registered with this provider.
     * <p>
     * A failure while notifying one session is logged and then suppressed, so it does
     * not surface as an error of the returned Mono.
     *
     * @param method The method name for the notification
     * @param params The parameters for the notification
     * @return A Mono that completes when the broadcast attempt is finished
     */
    @Override
    public Mono<Void> notifyClients(String method, Map<String, Object> params) {
        if (!sessions.isEmpty()) {
            logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());

            Flux<McpServerSession> recipients = Flux.fromIterable(sessions.values());
            return recipients
                    .flatMap(recipient -> recipient.sendNotification(method, params)
                            .doOnError(failure -> logger.error("Failed to send message to session {}: {}",
                                    recipient.getId(), failure.getMessage()))
                            .onErrorComplete())
                    .then();
        }

        logger.debug("No active sessions to broadcast message to");
        return Mono.empty();
    }

    /**
     * Handles GET requests to establish SSE connections.
     * <p>
     * The request is rejected with 404 when the request URI is not the SSE endpoint,
     * with 401 when the {@code token} query parameter is missing or blank, and with 503
     * while the transport is shutting down. Otherwise the response is configured as an
     * event stream, asynchronous processing is started with a 5-minute timeout, a new
     * session is created and registered, and an initial {@code endpoint} event is sent
     * that tells the client where to POST its messages.
     * <p>
     * The token is URL-encoded before being embedded in the advertised endpoint URL so
     * that reserved characters survive the round trip through the query string and so
     * that a token containing line breaks cannot break the SSE framing.
     *
     * @param request  The HTTP servlet request
     * @param response The HTTP servlet response
     * @throws IOException If an I/O error occurs
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response)
            throws IOException {

        String pathInfo = request.getRequestURI();
        if (!sseEndpoint.equals(pathInfo)) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        String token = request.getParameter("token");
        if (StringUtils.isEmpty(Strings.trimAllWhitespace(token))) {
            response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
            return;
        }

        if (isClosing.get()) {
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
            return;
        }

        response.setContentType("text/event-stream");
        response.setCharacterEncoding(UTF_8);
        response.setHeader("Cache-Control", "no-cache");
        response.setHeader("Connection", "keep-alive");
        response.setHeader("Access-Control-Allow-Origin", "*");

        String sessionId = UUID.randomUUID().toString();
        AsyncContext asyncContext = request.startAsync();
        asyncContext.setTimeout(300000); // 5-minute timeout

        // Every terminal async event removes the session so the map does not retain
        // entries for connections that are gone.
        asyncContext.addListener(new AsyncListener() {
            @Override
            public void onComplete(AsyncEvent event) throws IOException {
                logger.debug("AsyncContext completed for session {}", sessionId);
                sessions.remove(sessionId);
            }

            @Override
            public void onTimeout(AsyncEvent event) throws IOException {
                logger.warn("Session {} timeout, cleaning up", sessionId);
                sessions.remove(sessionId);
                try {
                    event.getAsyncContext().complete();
                } catch (Exception e) {
                    logger.error("Error completing async context on timeout", e);
                }
            }

            @Override
            public void onError(AsyncEvent event) throws IOException {
                // Clean up before logging: the event may carry no throwable, and the
                // session must be removed regardless of what can be reported.
                sessions.remove(sessionId);
                Throwable cause = event.getThrowable();
                logger.error("AsyncContext error for session {}: {}", sessionId,
                        cause != null ? cause.getMessage() : "unknown error");
                try {
                    event.getAsyncContext().complete();
                } catch (Exception e) {
                    logger.error("Error completing async context on error", e);
                }
            }

            @Override
            public void onStartAsync(AsyncEvent event) throws IOException {
                // Nothing to do when async processing is (re)started.
            }
        });

        PrintWriter writer = response.getWriter();

        // Create a new session transport
        HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext,
                writer);

        // Create a new session using the session factory
        McpServerSession session = sessionFactory.create(sessionTransport);
        this.sessions.put(sessionId, session);

        // Send initial endpoint event. The token is client-supplied, so it is encoded
        // before being placed in the URL; the servlet container decodes it again when
        // the client POSTs to this endpoint.
        String endpointUrl = messageEndpoint + "?sessionId=" + sessionId + "&token="
                + URLEncoder.encode(token, UTF_8);
        this.sendEvent(writer, ENDPOINT_EVENT_TYPE, endpointUrl);
    }

    /**
     * Handles POST requests for client messages.
     * <p>
     * The request is rejected with 503 while the transport is shutting down and with
     * 404 when the request URI is not the message endpoint. A missing {@code sessionId}
     * parameter yields a 400 JSON error and an unknown session a 404 JSON error.
     * Otherwise the request body is deserialized as a JSON-RPC message and handed,
     * together with the {@code token} parameter, to the session; the call blocks until
     * the session has handled the message. Any exception raised while reading,
     * deserializing or handling is logged and reported as a 500 JSON error.
     *
     * @param request  The HTTP servlet request
     * @param response The HTTP servlet response
     * @throws ServletException If a servlet-specific error occurs
     * @throws IOException      If an I/O error occurs
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {

        if (isClosing.get()) {
            response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
            return;
        }

        String pathInfo = request.getRequestURI();
        if (!messageEndpoint.equals(pathInfo)) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        // Get the session ID from the request parameter
        String sessionId = request.getParameter("sessionId");
        if (sessionId == null) {
            writeJsonError(response, HttpServletResponse.SC_BAD_REQUEST, "Session ID missing in message endpoint");
            return;
        }

        // Get the session from the sessions map
        McpServerSession session = sessions.get(sessionId);
        if (session == null) {
            writeJsonError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found: " + sessionId);
            return;
        }

        try {
            // Lines are concatenated without separators, as JSON does not need them.
            BufferedReader reader = request.getReader();
            StringBuilder body = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }

            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());

            String token = request.getParameter("token");
            // Process the message through the session's handle method
            session.handle(message, token).block(); // Block for Servlet compatibility

            response.setStatus(HttpServletResponse.SC_OK);
        } catch (Exception e) {
            // Pass the exception itself as the last argument so the stack trace is kept.
            logger.error("Error processing message: {}", e.getMessage(), e);
            try {
                writeJsonError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e.getMessage());
            } catch (IOException ex) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
                // sendError is illegal on a committed response; skip it in that case
                // rather than raising a second exception from the error path.
                if (!response.isCommitted()) {
                    response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
                }
            }
        }
    }

    /**
     * Writes an {@link McpError} carrying the given message as a JSON response body
     * with the given HTTP status.
     *
     * @param response The HTTP servlet response to write to
     * @param status   The HTTP status code to set
     * @param message  The error message to wrap in the {@link McpError}
     * @throws IOException If serializing or writing the error fails
     */
    private void writeJsonError(HttpServletResponse response, int status, String message) throws IOException {
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(status);
        String jsonError = objectMapper.writeValueAsString(new McpError(message));
        PrintWriter writer = response.getWriter();
        writer.write(jsonError);
        writer.flush();
    }

    /**
     * Starts a graceful shutdown of this transport provider.
     * <p>
     * The closing flag is raised first, which causes subsequent GET and POST requests
     * to be rejected; every session registered at subscription time is then asked to
     * close gracefully.
     *
     * @return A Mono that completes when all sessions have been closed
     */
    @Override
    public Mono<Void> closeGracefully() {
        isClosing.set(true);
        logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());

        Flux<McpServerSession> activeSessions = Flux.fromIterable(sessions.values());
        return activeSessions
                .flatMap(activeSession -> activeSession.closeGracefully())
                .then();
    }

    /**
     * Sends an SSE event to a client.
     * <p>
     * The event is framed as an {@code event:} line, one {@code data:} line per line of
     * {@code data}, and a terminating blank line. Splitting the payload on line breaks
     * keeps the frame well-formed when the data spans several lines (for example
     * pretty-printed JSON); single-line data produces exactly one {@code data:} line.
     * The whole frame is assembled first and handed to the writer in one call.
     *
     * @param writer    The writer to send the event through
     * @param eventType The type of event (message or endpoint)
     * @param data      The event data
     * @throws IOException If an error occurs while writing the event
     */
    private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
        StringBuilder frame = new StringBuilder();
        frame.append("event: ").append(eventType).append('\n');

        // A limit of -1 keeps trailing empty segments, so no part of the payload is
        // dropped. String.valueOf mirrors string concatenation for a null payload.
        for (String dataLine : String.valueOf(data).split("\r\n|\r|\n", -1)) {
            frame.append("data: ").append(dataLine).append('\n');
        }
        frame.append('\n');

        writer.write(frame.toString());
        writer.flush();

        // PrintWriter does not throw on write failures; checkError is the only way to
        // notice that the underlying stream has failed.
        if (writer.checkError()) {
            throw new IOException("Client disconnected");
        }
    }

    /**
     * Cleans up resources when the servlet is being destroyed.
     * <p>
     * Blocks until the graceful shutdown of all sessions has finished, then invokes the
     * parent's destroy method. The parent call sits in a {@code finally} block so it
     * still runs when the shutdown fails; the failure itself is propagated unchanged.
     */
    @Override
    public void destroy() {
        try {
            closeGracefully().block();
        } finally {
            super.destroy();
        }
    }

    /**
     * Implementation of McpServerTransport for HttpServlet SSE sessions. This class
     * handles the transport-level communication for a specific client session: it
     * writes JSON-RPC messages as SSE events to the session's writer and tears the
     * session down by removing it from the provider's session map and completing the
     * async context.
     */
    private class HttpServletMcpSessionTransport implements McpServerTransport {

        private final String sessionId;

        private final AsyncContext asyncContext;

        private final PrintWriter writer;

        /**
         * Creates a new session transport with the specified ID and SSE writer.
         *
         * @param sessionId    The unique identifier for this session
         * @param asyncContext The async context for the session
         * @param writer       The writer for sending server events to the client
         */
        HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) {
            this.sessionId = sessionId;
            this.asyncContext = asyncContext;
            this.writer = writer;
            logger.debug("Session transport {} initialized with SSE writer", sessionId);
        }

        /**
         * Sends a JSON-RPC message to the client through the SSE connection.
         * <p>
         * A failure to serialize or write the message is logged and the session is torn
         * down; the failure is not propagated through the returned Mono.
         *
         * @param message The JSON-RPC message to send
         * @return A Mono that completes when the send attempt has finished
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String jsonText = objectMapper.writeValueAsString(message);
                    sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText);
                    logger.debug("Message sent to session {}", sessionId);
                } catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
                    // Guarded teardown: completing the context may itself fail (for
                    // example if it was already completed), and that must not escape
                    // from the error path.
                    releaseSession();
                }
            });
        }

        /**
         * Converts data from one type to another using the configured ObjectMapper.
         *
         * @param data    The source data object to convert
         * @param typeRef The target type reference
         * @param <T>     The target type
         * @return The converted object of type T
         */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return objectMapper.convertValue(data, typeRef);
        }

        /**
         * Initiates a graceful shutdown of the transport.
         *
         * @return A Mono that completes when the shutdown is complete
         */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                logger.debug("Closing session transport: {}", sessionId);
                releaseSession();
            });
        }

        /**
         * Closes the transport immediately.
         */
        @Override
        public void close() {
            releaseSession();
        }

        /**
         * Returns the async context backing this session's SSE connection.
         *
         * @return The async context passed to the constructor
         */
        @Override
        public AsyncContext getAsyncContext() {
            return asyncContext;
        }

        /**
         * Removes this session from the provider's session map and completes its async
         * context. A failure to complete the context is logged and not rethrown.
         */
        private void releaseSession() {
            sessions.remove(sessionId);
            try {
                asyncContext.complete();
                logger.debug("Successfully completed async context for session {}", sessionId);
            } catch (Exception e) {
                logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage());
            }
        }
    }

}
